# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2025.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Entanglement Concentration
"""
from __future__ import annotations

import warnings
import os
import json

import numpy as np

from qiskit import QuantumCircuit
from qiskit.quantum_info import Operator, Statevector
from qiskit.circuit import ParameterVector

from ..utils import algorithm_globals


def entanglement_concentration_data(
    training_size: int,
    test_size: int,
    n: int,
    *,
    mode: str = "easy",
    one_hot: bool = True,
    include_sample_total: bool = False,
    sampling_method: str = "cardinal",
    class_labels: list | None = None,
    formatting: str = "ndarray",
) -> (
    tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
    | tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
    | tuple[list[Statevector], np.ndarray, list[Statevector], np.ndarray]
    | tuple[list[Statevector], np.ndarray, list[Statevector], np.ndarray, np.ndarray]
):
    r"""
    Generates a dataset that comprises Quantum States with two different
    amounts of Concentration Of Entanglement (CE) and their corresponding class labels.
    These states are generated by the effect of two different pre-trained ansatz
    on fully separable input states according to the procedure outlined in [1]. Pre-trained
    data courtesy of L. Schatzki et al. [3]. The datapoints can be fully separated using
    the SWAP test outlined in [2]. First, input states are randomly generated from a
    uniform distribution, using a sampling method determined by the ``sampling_method``
    argument. Next, based on the ``mode`` argument, two pre-trained circuits "A" and "B"
    are used for generating datapoints.


    CE can be interpreted as a measure of correlation between the different qubits.
    The ``mode`` argument supports two options. ``"easy"`` gives datapoints with high CE
    difference hence being easy to separate. ``"hard"`` mode gives closer CE values.
    The user's classifiers can be benchmarked against these modes for their ability to
    separate the data into two classes based on CE.


    Current implementation supports only ``n`` values of 3 and 4, which are the qubit
    counts that have an entry in the internal depth table.


    ``sampling_method`` argument supports two options. ``"isotropic"`` and ``"cardinal"``.
    Isotropic generates qubit states that are sampled randomly in the Bloch Sphere and takes the
    tensor product of all the qubits to build the input state. Cardinal generates only states
    that fall on the axes of the Bloch Sphere before taking the tensor product.


    **References:**

    [1] Schatzki L, Arrasmith A, Coles PJ, Cerezo M. Entangled datasets for quantum machine learning.
    arXiv preprint. 2021 Sep; arXiv:2109.03400.
    https://arxiv.org/abs/2109.03400

    [2] Beckey JL, Gigena N, Coles PJ, Cerezo M. Computable and operationally meaningful multipartite
    entanglement measures. Physical Review Letters. 2021 Sep 27; 127(14):140501.
    doi:10.1103/PhysRevLett.127.140501
    https://doi.org/10.1103/PhysRevLett.127.140501

    [3] Schatzki L. NTangled Datasets - Hardware Efficient [dataset]. GitHub.
    2022 Mar 2 (commit f3a68ff).
    https://github.com/LSchatzki/NTangled_Datasets/tree/main/Hardware_Efficient


    Parameters:
        training_size : Number of training samples per class.
        test_size :  Number of testing samples per class.
        n : Number of qubits (dimension of the feature space). Current implementation
            supports only 3 and 4
        mode :
            Choices are:

                * ``"easy"``: uses CE values 0.18 and 0.40 for n = 3 and 0.12 and 0.43 for n = 4
                * ``"hard"``: uses CE values 0.28 and 0.40 for n = 3 and 0.22 and 0.34 for n = 4

            Default is ``"easy"``.
        one_hot : If True, returns labels in one-hot format. Default is True.
        include_sample_total : If True, the function also returns the total number
            of generated samples. Default is False.
        sampling_method: The method used to generate input states.
            Choices are:

                * ``"isotropic"``: samples qubit states uniformly in the bloch sphere
                * ``"cardinal"``: samples qubit states out of the 6 axes of bloch sphere

            Default is ``"cardinal"``.
        class_labels : Custom labels for the two classes when one-hot is not enabled.
            If not provided, the labels default to ``0`` and ``1``
        formatting: The format in which datapoints are given.
            Choices are:

                * ``"ndarray"``: gives a numpy array of shape (n_points, 2**n_qubits, 1)
                * ``"statevector"``: gives a python list of Statevector objects

            Default is ``"ndarray"``.

    Returns:
        Tuple
        containing the following:

        * **training_features** : ``np.ndarray`` | ``qiskit.quantum_info.Statevector``
        * **training_labels** : ``np.ndarray``
        * **testing_features** : ``np.ndarray`` | ``qiskit.quantum_info.Statevector``
        * **testing_labels** : ``np.ndarray``

        If ``include_sample_total=True``, a fifth element (``np.ndarray``) is included
        that holds the total number of samples generated over both classes, i.e.
        ``2 * (training_size + test_size)``.

    Raises:
        ValueError: If a size is negative, if ``n``, ``mode``, ``sampling_method`` or
            ``formatting`` is not one of the supported values, or if ``"cardinal"``
            sampling is asked for at least ``6**n`` points per class.
    """
    # Default Value
    if class_labels is None:
        class_labels = [0, 1]

    n_points = training_size + test_size

    # Depth Settings: (qubits, mode) -> (depth of the low-CE circuit, depth of the high-CE one).
    # This table is the single source of truth for which qubit counts can be generated.
    depth = {(3, "easy"): (2, 6), (3, "hard"): (2, 5), (4, "easy"): (2, 6), (4, "hard"): (2, 5)}
    supported_n = sorted({n_qubits for n_qubits, _ in depth})

    # Errors
    if training_size < 0:
        raise ValueError("Training size can't be less than 0")
    if test_size < 0:
        raise ValueError("Test size can't be less than 0")
    if n not in supported_n:
        # Rejecting here (instead of accepting e.g. n=8) avoids an opaque KeyError
        # at the depth lookup further down.
        raise ValueError(
            f"Currently only {' and '.join(str(q) for q in supported_n)} qubits are supported"
        )
    if mode not in {"easy", "hard"}:
        raise ValueError("Invalid mode. Must be 'easy' or 'hard'")
    if sampling_method not in {"isotropic", "cardinal"}:
        raise ValueError("Invalid sampling method. Must be 'isotropic' or 'cardinal'")
    if sampling_method == "cardinal" and n_points >= (6**n):
        raise ValueError(
            """Cardinal Sampling cannot generate a large number of unique 
            datapoints due to the limited number of combinations possible. 
            Try "isotropic" sampling method"""
        )
    if formatting not in {"statevector", "ndarray"}:
        raise ValueError(
            """Formatting must be "statevector" or "ndarray". Please check for 
            case sensitivity."""
        )

    # Warnings
    if sampling_method == "cardinal" and n_points > (3**n):
        warnings.warn(
            """Cardinal Sampling for large number of samples is not recommended 
            and can lead to an arbitrarily large generation time due to 
            repeating datapoints. Try "isotropic" sampling method""",
            UserWarning,
        )

    d_low, d_high = depth[(n, mode)]

    # Import Models
    qc_low = QuantumCircuit(n)
    qc_high = QuantumCircuit(n)

    params_low = ParameterVector("low", d_low * n * 3)
    _hardware_efficient_ansatz(qc_low, params_low, n, d_low)
    bound_qc_low = _assign_parameters(n, mode, "low", d_low, qc_low)

    params_high = ParameterVector("high", d_high * n * 3)
    _hardware_efficient_ansatz(qc_high, params_high, n, d_high)
    bound_qc_high = _assign_parameters(n, mode, "high", d_high, qc_high)

    # Convert them to Unitaries for batch processing
    u_low = Operator(bound_qc_low, input_dims=(2**n, 1), output_dims=(2**n, 1)).data
    u_high = Operator(bound_qc_high, input_dims=(2**n, 1), output_dims=(2**n, 1)).data

    # Sampling Input States
    if sampling_method == "isotropic":
        psi_in = _isotropic(n, n_points)
    else:
        psi_in = _cardinal(n, n_points)

    # The same input states are pushed through both circuits; matmul broadcasts the
    # (2**n, 2**n) unitary over the leading sample axis of psi_in.
    a_features = u_low @ psi_in
    b_features = u_high @ psi_in

    # The first `training_size` samples of each class form the training split,
    # the remaining `test_size` samples form the test split.
    if formatting == "ndarray":
        x_train = np.concatenate((a_features[:training_size], b_features[:training_size]), axis=0)
        x_test = np.concatenate((a_features[training_size:], b_features[training_size:]), axis=0)
    else:
        x_train = [Statevector(v) for v in a_features[:training_size, :, 0]] + [
            Statevector(v) for v in b_features[:training_size, :, 0]
        ]
        x_test = [Statevector(v) for v in a_features[training_size:, :, 0]] + [
            Statevector(v) for v in b_features[training_size:, :, 0]
        ]

    # Labels follow the feature layout: all class "A" samples first, then class "B".
    if one_hot:
        y_train = np.array([[1, 0]] * training_size + [[0, 1]] * training_size)
        y_test = np.array([[1, 0]] * test_size + [[0, 1]] * test_size)
    else:
        y_train = np.array([class_labels[0]] * training_size + [class_labels[1]] * training_size)
        y_test = np.array([class_labels[0]] * test_size + [class_labels[1]] * test_size)

    if include_sample_total:
        samples = np.array([n_points * 2])
        return (x_train, y_train, x_test, y_test, samples)

    return (x_train, y_train, x_test, y_test)


def _assign_parameters(
    n_qubits: int,
    mode: str,
    label: str,
    depth: int,
    qc: QuantumCircuit,
) -> QuantumCircuit:
    """Load pre-trained parameters from ``models/`` and bind them to ``qc``.

    The weights are read from
    ``models/entanglement_{mode}_{label}_{n_qubits}qubits.json`` located next to
    this module and flattened into a one-dimensional array before binding.

    Parameters:
        n_qubits : Number of qubits, used in the file name and the size check.
        mode : Dataset mode, used in the file name.
        label : Circuit label, used in the file name.
        depth : Number of ansatz layers, used for the size check.
        qc : Parameterized circuit that receives the weights.

    Returns:
        The result of ``qc.assign_parameters(weights, inplace=False)``.

    Raises:
        ValueError: If the file does not hold exactly ``3 * depth * n_qubits`` values.
    """

    file_path = os.path.join(
        os.path.dirname(__file__), "models", f"entanglement_{mode}_{label}_{n_qubits}qubits.json"
    )
    # JSON is read as UTF-8 explicitly so that decoding does not depend on the
    # platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as weights_file:
        weights = np.array(json.load(weights_file)).flatten()

    expected = 3 * depth * n_qubits
    if len(weights) != expected:
        raise ValueError(
            """Parameter mismatch – please reinstall the latest 'qiskit-machine-learning' 
            package (or update the model files).""",
        )

    return qc.assign_parameters(weights, inplace=False)


def _hardware_efficient_ansatz(
    qc: QuantumCircuit, params: ParameterVector, n_qubits: int, depth: int
) -> None:
    """Append ``depth`` hardware-efficient layers to ``qc`` in place.

    Every layer consumes ``3 * n_qubits`` consecutive entries of ``params`` and
    consists of four steps: a ``u`` rotation on each qubit, ``cz`` gates on the
    pairs (0, 1), (2, 3), ..., the same ``u`` rotations once more with the same
    three angles per qubit, and ``cz`` gates on the pairs (1, 2), (3, 4), ...
    """
    angles_per_layer = 3 * n_qubits

    def _rotate_every_qubit(base: int) -> None:
        # Qubit k takes the three consecutive angles starting at base + 3 * k.
        for qubit in range(n_qubits):
            first = base + 3 * qubit
            theta, phi, lamb = params[first : first + 3]
            qc.u(theta, phi, lamb, qubit)

    for layer in range(depth):
        base = layer * angles_per_layer

        _rotate_every_qubit(base)
        for lower in range(0, n_qubits - 1, 2):
            qc.cz(lower, lower + 1)

        # The second rotation block re-reads this layer's angles instead of new ones.
        _rotate_every_qubit(base)
        for lower in range(1, n_qubits - 1, 2):
            qc.cz(lower, lower + 1)


def _cardinal(n_qubits: int, n_points: int) -> np.ndarray:
    """Sample distinct product states whose qubits lie on the axes of the Bloch Sphere.

    Each qubit is one of the six axis states
        |0>, |1>, |+>, |->, |i>, |-i>
    and a product state is the Kronecker product of its qubit states, with
    qubit 0 as the leftmost factor. ``n_points`` distinct combinations are drawn
    without replacement out of the ``6**n_qubits`` possible ones.

    Returns an array of shape ``(n_points, 2**n_qubits, 1)``.
    """
    root_two = np.sqrt(2.0)
    cardinal_states = (
        np.array(
            [
                [1.0, 0.0],  # |0>
                [0.0, 1.0],  # |1>
                [1.0, 1.0],  # |+>
                [1.0, -1.0],  # |->
                [1.0, 1.0j],  # |i>
                [1.0, -1.0j],  # |-i>
            ],
            dtype=np.complex128,
        )
        / root_two
    )
    # |0> and |1> were already normalised, so the division is reverted for them.
    cardinal_states[:2] *= root_two

    generator = algorithm_globals.random

    # Every drawn integer encodes one combination as an n_qubits-digit base-6 number
    # whose most significant digit selects the state of qubit 0.
    combination_ids = generator.choice(6**n_qubits, size=n_points, replace=False)
    place_values = 6 ** np.arange(n_qubits - 1, -1, -1)
    state_ids = (combination_ids[:, None] // place_values) % 6

    per_qubit = cardinal_states[state_ids]

    # Row k of bit_table is the binary expansion of basis index k, MSB first,
    # i.e. the basis label seen by qubits 0 .. n_qubits - 1.
    basis = np.arange(2**n_qubits, dtype=np.uint16)
    shifts = np.arange(n_qubits - 1, -1, -1)
    bit_table = (basis[:, None] >> shifts) & 1

    # factors[p, k, q] is the amplitude of qubit q of sample p on bit q of basis state k;
    # multiplying over q evaluates the Kronecker product entry by entry.
    factors = per_qubit[:, np.arange(n_qubits), bit_table]
    product_states = factors.prod(axis=2)

    return product_states[:, :, np.newaxis]


def _isotropic(n_qubits: int, n_points: int) -> np.ndarray:
    """Sample product states whose qubits are uniformly distributed on the Bloch Sphere.

    For every qubit a height ``z`` in [-1, 1] and an azimuth ``phi`` in [0, 2*pi) are
    drawn uniformly, giving the state ``cos(theta/2)|0> + exp(i*phi) sin(theta/2)|1>``
    with ``theta = arccos(z)``. The qubit states of a sample are combined through a
    Kronecker product, with qubit 0 as the leftmost factor.

    Returns an array of shape ``(n_points, 2**n_qubits, 1)``.
    """
    generator = algorithm_globals.random
    draw_shape = (n_points, n_qubits)

    # Uniform height and uniform azimuth together give a uniform point on the sphere.
    height = generator.uniform(-1, 1, size=draw_shape)
    azimuth = generator.uniform(0, 2 * np.pi, size=draw_shape)

    half_polar = np.arccos(height) / 2
    amp_zero = np.cos(half_polar)
    amp_one = np.sin(half_polar) * np.exp(1j * azimuth)
    per_qubit = np.stack([amp_zero, amp_one], axis=-1)

    # Row k of bit_table is the binary expansion of basis index k, MSB first.
    basis = np.arange(2**n_qubits, dtype=np.uint16)
    shifts = np.arange(n_qubits - 1, -1, -1)
    bit_table = (basis[:, None] >> shifts) & 1

    # factors[p, k, q] is the amplitude of qubit q of sample p on bit q of basis state k;
    # multiplying over q evaluates the Kronecker product entry by entry.
    factors = per_qubit[:, np.arange(n_qubits), bit_table]
    product_states = factors.prod(axis=2)

    return product_states[:, :, np.newaxis]
